--- external/eg3d/training/volumetric_rendering/ray_marcher.py	2023-03-14 23:34:32.199358815 +0000
+++ external_reference/eg3d/training/volumetric_rendering/ray_marcher.py	2023-03-14 23:28:05.416774738 +0000
@@ -22,11 +22,12 @@
         super().__init__()
 
 
-    def run_forward(self, colors, densities, depths, rendering_options):
+    def run_forward(self, colors, densities, depths, rendering_options, noise):
         deltas = depths[:, :, 1:] - depths[:, :, :-1]
         colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
         densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
         depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
+        noise_mid = (noise[:, :, :-1] + noise[:, :, 1:]) / 2
 
 
         if rendering_options['clamp_mode'] == 'softplus':
@@ -34,12 +35,23 @@
         else:
             assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!"
 
+        # Multiply the weights by a depth-dependent decay factor that falls linearly
+        # from 1 at `decay_start` to 0 at `ray_end`; reduces the pop-up effect at the horizon.
+        if rendering_options['decay_start'] is not None:
+            decay_start = rendering_options['decay_start']
+            ray_end = rendering_options['ray_end']
+            decay_weight = torch.clamp((depths_mid - decay_start) / (ray_end - decay_start), 0, 1)
+            decay_weight = 1-decay_weight
+        else:
+            decay_weight = torch.ones_like(depths_mid)
+
         density_delta = densities_mid * deltas
 
         alpha = 1 - torch.exp(-density_delta)
 
         alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
         weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
+        weights = weights * decay_weight
 
         composite_rgb = torch.sum(weights * colors_mid, -2)
         weight_total = weights.sum(2)
@@ -49,15 +61,21 @@
         composite_depth = torch.nan_to_num(composite_depth, float('inf'))
         composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
 
+        # Disparity map: weighted sum of inverse mid-point depths, clamped to [0, 1].
+        composite_disp = torch.sum(weights * 1/depths_mid, -2)
+        composite_disp = torch.nan_to_num(composite_disp, float('inf'))
+        composite_disp = torch.clamp(composite_disp, 0., 1.)
+
         if rendering_options.get('white_back', False):
             composite_rgb = composite_rgb + 1 - weight_total
 
         composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
 
-        return composite_rgb, composite_depth, weights
+        composite_noise = torch.sum(weights * noise_mid, -2)
 
+        return composite_rgb, composite_depth, composite_disp, weights, composite_noise
 
-    def forward(self, colors, densities, depths, rendering_options):
-        composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
 
-        return composite_rgb, composite_depth, weights
\ No newline at end of file
+    def forward(self, colors, densities, depths, rendering_options, noise):
+        """Delegate to `run_forward` and return its five outputs unchanged."""
+        return self.run_forward(colors, densities, depths, rendering_options, noise)
